"""initialize tables

Revision ID: 1dd9fa5b38ff
Revises: 
Create Date: 2024-07-24 11:13:29.124449

"""
from typing import Sequence, Union
from pathlib import Path

from alembic import op, context
import sqlalchemy as sa
import sqlmodel
import gpustack
import logging
from gpustack import __version__

# Module-level logger; used by write_bootstrap_version() to report the
# outcome of writing the bootstrap_version file.
logger = logging.getLogger(__name__)

# revision identifiers, used by Alembic.
# down_revision is None, so this is the root migration of the chain: it
# creates the initial schema and every later revision builds on it.
revision: str = '1dd9fa5b38ff'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Create the initial GPUStack schema.

    Creates the ``api_keys``, ``models``, ``system_loads``, ``users``,
    ``workers``, ``model_instances`` and ``model_usages`` tables together with
    their unique indexes. Afterwards, unless the Alembic main option
    ``called_by_db_migration`` is the string ``"true"``, it writes the
    ``bootstrap_version`` file via :func:`write_bootstrap_version`.

    Statement order matters: tables that declare foreign keys
    (``model_instances``, ``model_usages``) are created after the tables they
    reference.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('api_keys',
    sa.Column('deleted_at', sa.DateTime(), nullable=True),
    sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
    sa.Column('description', sa.Text(), nullable=True),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('access_key', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
    sa.Column('hashed_secret_key', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
    sa.Column('user_id', sa.Integer(), nullable=False),
    sa.Column('expires_at', gpustack.schemas.common.UTCDateTime(), nullable=True),
    sa.Column('created_at', gpustack.schemas.common.UTCDateTime(), nullable=False),
    sa.Column('updated_at', gpustack.schemas.common.UTCDateTime(), nullable=False),
    sa.PrimaryKeyConstraint('id'),
    sa.UniqueConstraint('hashed_secret_key'),
    sa.UniqueConstraint('name', 'user_id', name='uix_name_user_id')
    )
    op.create_index(op.f('ix_api_keys_access_key'), 'api_keys', ['access_key'], unique=True)
    # 'sourceenum' is declared here and again on model_instances below; both
    # columns share the same named enum type.
    op.create_table('models',
    sa.Column('deleted_at', sa.DateTime(), nullable=True),
    sa.Column('source', sa.Enum('HUGGING_FACE', 'OLLAMA_LIBRARY', 'MODEL_SCOPE', 'LOCAL_PATH', name='sourceenum'), nullable=False),
    sa.Column('huggingface_repo_id', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
    sa.Column('huggingface_filename', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
    sa.Column('ollama_library_model_name', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
    sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
    sa.Column('description', sa.Text(), nullable=True),
    sa.Column('replicas', sa.Integer(), nullable=False),
    sa.Column('ready_replicas', sa.Integer(), nullable=False),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', gpustack.schemas.common.UTCDateTime(), nullable=False),
    sa.Column('updated_at', gpustack.schemas.common.UTCDateTime(), nullable=False),
    sa.PrimaryKeyConstraint('id')
    )
    op.create_index(op.f('ix_models_name'), 'models', ['name'], unique=True)
    op.create_table('system_loads',
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('timestamp', sa.Integer(), nullable=False),
    sa.Column('cpu', sa.Float(), nullable=True),
    sa.Column('memory', sa.Float(), nullable=True),
    sa.Column('gpu', sa.Float(), nullable=True),
    sa.Column('gpu_memory', sa.Float(), nullable=True),
    sa.PrimaryKeyConstraint('id')
    )
    op.create_table('users',
    sa.Column('deleted_at', sa.DateTime(), nullable=True),
    sa.Column('username', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
    sa.Column('is_admin', sa.Boolean(), nullable=False),
    sa.Column('full_name', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
    sa.Column('require_password_change', sa.Boolean(), nullable=False),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('hashed_password', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
    sa.Column('created_at', gpustack.schemas.common.UTCDateTime(), nullable=False),
    sa.Column('updated_at', gpustack.schemas.common.UTCDateTime(), nullable=False),
    sa.PrimaryKeyConstraint('id')
    )
    op.create_table('workers',
    sa.Column('deleted_at', sa.DateTime(), nullable=True),
    sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
    sa.Column('hostname', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
    sa.Column('ip', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
    sa.Column('labels', sa.JSON(), nullable=True),
    sa.Column('system_reserved', gpustack.schemas.common.JSON(), nullable=True),
    sa.Column('state', sa.Enum('NOT_READY', 'READY', name='workerstateenum'), nullable=False),
    sa.Column('state_message', sa.Text(), nullable=True),
    sa.Column('status', gpustack.schemas.common.JSON(), nullable=True),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', gpustack.schemas.common.UTCDateTime(), nullable=False),
    sa.Column('updated_at', gpustack.schemas.common.UTCDateTime(), nullable=False),
    sa.PrimaryKeyConstraint('id')
    )
    op.create_index(op.f('ix_workers_name'), 'workers', ['name'], unique=True)
    # model_instances.model_id has a foreign key to models.id, so it must be
    # created after 'models'. worker_id has no FK constraint here.
    op.create_table('model_instances',
    sa.Column('deleted_at', sa.DateTime(), nullable=True),
    sa.Column('source', sa.Enum('HUGGING_FACE', 'OLLAMA_LIBRARY', 'MODEL_SCOPE', 'LOCAL_PATH', name='sourceenum'), nullable=False),
    sa.Column('huggingface_repo_id', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
    sa.Column('huggingface_filename', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
    sa.Column('ollama_library_model_name', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
    sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
    sa.Column('worker_id', sa.Integer(), nullable=True),
    sa.Column('worker_name', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
    sa.Column('worker_ip', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
    sa.Column('pid', sa.Integer(), nullable=True),
    sa.Column('port', sa.Integer(), nullable=True),
    sa.Column('download_progress', sa.Float(), nullable=True),
    sa.Column('state', sa.Enum('INITIALIZING', 'PENDING', 'RUNNING', 'SCHEDULED', 'ERROR', 'DOWNLOADING', 'ANALYZING', name='modelinstancestateenum'), nullable=False),
    sa.Column('state_message', sa.Text(), nullable=True),
    sa.Column('computed_resource_claim', gpustack.schemas.common.JSON(), nullable=True),
    sa.Column('gpu_index', sa.Integer(), nullable=True),
    sa.Column('model_id', sa.Integer(), nullable=False),
    sa.Column('model_name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', gpustack.schemas.common.UTCDateTime(), nullable=False),
    sa.Column('updated_at', gpustack.schemas.common.UTCDateTime(), nullable=False),
    sa.ForeignKeyConstraint(['model_id'], ['models.id'], ),
    sa.PrimaryKeyConstraint('id')
    )
    op.create_index(op.f('ix_model_instances_name'), 'model_instances', ['name'], unique=True)
    # model_usages has named foreign keys to both models.id and users.id, so
    # it is created last.
    op.create_table('model_usages',
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('user_id', sa.Integer(), nullable=False),
    sa.Column('model_id', sa.Integer(), nullable=False),
    sa.Column('date', sa.Date(), nullable=False),
    sa.Column('prompt_token_count', sa.Integer(), nullable=False),
    sa.Column('completion_token_count', sa.Integer(), nullable=False),
    sa.Column('request_count', sa.Integer(), nullable=False),
    sa.Column('operation', sa.Enum('CHAT_COMPLETION', 'COMPLETION', 'EMBEDDING', 'RERANK','IMAGE_GENERATION', name='operationenum'), nullable=False),
    sa.ForeignKeyConstraint(['model_id'], ['models.id'], name='fk_model_usages_model_id_models'),
    sa.ForeignKeyConstraint(['user_id'], ['users.id'], name='fk_model_usages_user_id_users'),
    sa.PrimaryKeyConstraint('id')
    )
    # Write bootstrap_version unless the caller set the Alembic main option
    # 'called_by_db_migration' to "true". Presumably that flag marks a run
    # that is not a fresh bootstrap; TODO confirm against whoever sets it.
    # NOTE(review): get_context().config may be None if no Alembic
    # EnvironmentContext is attached. Confirm this never runs outside env.py.
    if context.get_context().config.get_main_option("called_by_db_migration") != "true":
        write_bootstrap_version()
    # ### end Alembic commands ###


def write_bootstrap_version():
    """Write the running GPUStack version to ``<data_dir>/bootstrap_version``.

    As of v2.0.1, the initial migration writes this file so that future
    releases can make compatibility decisions based on the version that
    originally bootstrapped the installation. An existing file is overwritten.

    The data directory comes from the global GPUStack config. If there is no
    config, or its ``data_dir`` is empty/None, ``/var/lib/gpustack`` is used.

    Writing is best-effort. Any failure is logged at ERROR level and
    swallowed so that it never aborts the schema migration.
    """
    gpustack_config = gpustack.config.config.get_global_config()
    # Also fall back when data_dir is empty: Path("") resolves to the current
    # working directory, which would drop the file in an arbitrary location.
    data_dir = (
        gpustack_config.data_dir if gpustack_config else None
    ) or "/var/lib/gpustack"

    try:
        version_file = Path(data_dir) / "bootstrap_version"
        # mkdir(exist_ok=True) raises FileExistsError only when data_dir
        # exists but is not a directory. That is a real misconfiguration, so
        # it is reported like any other failure instead of being ignored.
        version_file.parent.mkdir(parents=True, exist_ok=True)
        version_file.write_text(__version__, encoding="utf-8")
    except Exception as e:
        # Best-effort: the file is advisory and must not fail the migration.
        logger.error("Error creating version file: %s", e)
        return

    logger.debug("Created bootstrap_version file at %s", version_file)

def downgrade() -> None:
    """Revert :func:`upgrade` by dropping every table and index it created.

    Tables are dropped in reverse dependency order. ``model_usages`` and
    ``model_instances`` hold foreign keys to ``models``/``users``, so they go
    first.

    On PostgreSQL, ``sa.Enum`` columns are backed by standalone named types.
    ``op.drop_table`` does not remove them, so a later re-upgrade would fail
    with "type ... already exists". They are therefore dropped explicitly,
    and only on that dialect.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_table('model_usages')
    op.drop_index(op.f('ix_model_instances_name'), table_name='model_instances')
    op.drop_table('model_instances')
    op.drop_index(op.f('ix_workers_name'), table_name='workers')
    op.drop_table('workers')
    op.drop_table('users')
    op.drop_table('system_loads')
    op.drop_index(op.f('ix_models_name'), table_name='models')
    op.drop_table('models')
    op.drop_index(op.f('ix_api_keys_access_key'), table_name='api_keys')
    op.drop_table('api_keys')
    # ### end Alembic commands ###

    # Enum types created by upgrade(); 'sourceenum' is shared by the models
    # and model_instances tables, so it is dropped once, after both tables.
    if op.get_bind().dialect.name == 'postgresql':
        for enum_name in (
            'operationenum',
            'modelinstancestateenum',
            'workerstateenum',
            'sourceenum',
        ):
            op.execute(f'DROP TYPE IF EXISTS {enum_name}')
